# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Calculate the ground state of a 2d quantum Ising model with local terms added
using CustomSiteTerm.
=======================================================

Simple example on how to set up local terms with CustomSiteTerm
"""

# pylint: disable=invalid-name

import qtealeaves as qtl
from qtealeaves import modeling, operators

###############################################################################
# We wrap the example in a ``main`` function so that importing this module
# does not automatically run the simulation.


def main(tn_type=5, output_folder=None):
    """
    Main method for the ground state simulation of a quantum Ising
    model in 2D with additional terms defined via ``CustomSiteTerm``.

    **Arguments**

    tn_type : int, optional
        Choose 5 for TTN, 6 for MPS.
        Default to 5.

    output_folder : str | callable | None, optional
        Output folder. If None, a folder parametrized by the linear
        system size ``L`` is used. Default to None.

    **Returns**

    energies : list
        Ground state energy for each parameter dictionary that is
        simulated, in the same order as the parameter list.
    """

###############################################################################
# Defining the 2d model: nearest-neighbor ``sx``-``sx`` couplings of strength
# ``J`` and a transverse field ``sz`` of strength ``g``.
    dim = 2

    def model_name(params):
        return f"CustomSite_ex_2d_QIsing_g{params['g']:2.4f}"

    model = modeling.QuantumModel(dim, "L", name=model_name)
    # Transverse field: without it, ``g`` would not enter the Hamiltonian
    # and the model would reduce to a classical Ising model.
    model += modeling.LocalTerm("sz", strength="g", prefactor=-1)
    model += modeling.TwoBodyTerm2D(
        ["sx", "sx"], [0, 1], strength="J", prefactor=-1, has_obc=True
    )
    # Here, I add a local term on the site with coordinates [0, 0].
    # As the system is 2D, a nested list is required even for local terms.
    model += modeling.CustomSiteTerm("sz", [[0, 0]], strength="loc_sz")
    # Here, I add a three-body term on the sites with coordinates
    # [0, 0], [0, 1], [1, 0]. As the system is 2D, a nested list is required.
    model += modeling.CustomSiteTerm(
        ["sz", "sx", "sx"],
        [[0, 0], [0, 1], [1, 0]],
        strength="3_body",
    )

    # Here, I add a local term using a callable that computes its position
    # from the parameters. For even L there is no unique center site; the
    # term is placed on [L // 2, L // 2], e.g. [2, 2] for L = 4.
    def center_site(params):
        half = params["L"] // 2
        return [[half, half]]

    model += modeling.CustomSiteTerm("sz", center_site, strength="loc_sz")

    # Here, I add an interaction term acting on the four corners of the
    # lattice. The callable takes the linear system size from params.
    def corner_sites(params):
        last = params["L"] - 1
        return [[0, 0], [last, 0], [0, last], [last, last]]

    model += modeling.CustomSiteTerm(
        ["sz", "sz", "sz", "sz"],
        corner_sites,
        strength="4_body",
    )
    my_ops = operators.TNSpin12Operators()
###############################################################################
# We define a **parametric** output folder to keep the results ordered. As
# you can see, it is parametrized through the linear size L of the lattice,
# i.e. the system has L x L physical sites.

    def default_output_folder(params):
        return f"CustomSite_ex_2d_QIsing_L{params['L']}"

    if output_folder is None:
        output_folder = default_output_folder

###############################################################################
# We define the convergence parameters and the observables. Both are
# important:
#
# - The convergence parameters ensure we obtain a reliable result. See
#   :docs:`/../chapters/convergence` for further information about them.
# - The observables ensure we are measuring (and storing) something at the end
#   of the simulation! See :docs:`/../chapters/measurements` for further
#   information about the available observables.

    my_conv = qtl.convergence_parameters.TNConvergenceParameters(
        max_iter=5, max_bond_dimension=16
    )
    my_obs = qtl.observables.TNObservables()

###############################################################################
# Define the simulation instance

    simulation = qtl.QuantumGreenTeaSimulation(
        model,
        my_ops,
        my_conv,
        my_obs,
        tn_type=tn_type,
        folder_name_output=output_folder,
        store_checkpoints=False,
    )

###############################################################################
# Define the parameters of the model: here we define the 'L' seen in the
# model definition at the beginning! 'J' and 'g' are the Ising couplings,
# while 'loc_sz', '3_body', and '4_body' are the strengths of the terms
# added with CustomSiteTerm.

    params = [
        {
            "L": 4,
            # model parameters
            "J": 1.0,
            "g": 0.5,
            "loc_sz": 0.3,
            "3_body": 0.4,
            "4_body": 0.6,
        }
    ]

###############################################################################
# We finally run the simulation and collect the ground state energies.

    simulation.run(params, delete_existing_folder=True)

    energies = []
    for elem in params:
        tn_energy_0 = simulation.get_static_obs(elem)["energy"]
        print("Ground state energy", tn_energy_0)
        energies.append(tn_energy_0)

    return energies


###############################################################################
# Run the code

if __name__ == "__main__":
    main()
